package com.zboot.comm.datasource.sharding.service;

import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import com.zboot.comm.datasource.sharding.appender.ShardingDataNodeAppender;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;
import org.apache.shardingsphere.infra.datanode.DataNode;
import org.apache.shardingsphere.sharding.rule.ShardingRule;
import org.apache.shardingsphere.sharding.rule.TableRule;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.stereotype.Service;

import javax.annotation.Resource;
import java.lang.reflect.Field;
import java.lang.reflect.Modifier;
import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.Statement;
import java.util.*;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * Runs DDL for sharding tables (looking up and creating actual tables) and refreshes the
 * in-memory ShardingSphere table rules at runtime.
 */
@Slf4j
@Service
public class ShardingTableRuleService {

    // Connection settings of the master data source. The methods below open raw JDBC
    // connections from url/username/password instead of going through ShardingSphere.
    // NOTE(review): driverClassName is not read anywhere in this class; DriverManager is
    // presumably locating the driver on its own — confirm it is still needed.
    @Value("${spring.shardingsphere.datasource.master-datasource.driver-class-name}")
    private String driverClassName;
    @Value("${spring.shardingsphere.datasource.master-datasource.url}")
    private String url;
    @Value("${spring.shardingsphere.datasource.master-datasource.username}")
    private String username;
    @Value("${spring.shardingsphere.datasource.master-datasource.password}")
    private String password;
    @Autowired
    private ShardingRule shardingRule;
    @Resource(name = "shardingDataNodeAppenderAgent")
    ShardingDataNodeAppender shardingDataNodeAppender;

    /**
     * Lists the existing physical tables whose names start with {@code logicTableName}.
     * <p>
     * Plain JDBC is used instead of MyBatis because MyBatis queries are intercepted by
     * ShardingSphere, which produces wrong results for this lookup. A new connection is
     * opened for every call and is always closed before returning.
     *
     * @param logicTableName logical table name used as a literal name prefix; quotes,
     *                       backslashes and the LIKE wildcards {@code %}/{@code _} in it are
     *                       escaped so they match themselves only
     * @return names of the matching tables, in the order the database returns them; never {@code null}
     * @throws NullPointerException if {@code logicTableName} is {@code null}
     */
    @SneakyThrows
    public List<String> getExistsTableNames(String logicTableName) {
        Objects.requireNonNull(logicTableName, "logicTableName");
        String sql = "show tables like '" + escapeLikePrefix(logicTableName) + "%'";
        List<String> tableNames = new ArrayList<>();
        // try-with-resources closes ResultSet, Statement and Connection even if one close() fails.
        try (Connection connection = DriverManager.getConnection(url, username, password);
             Statement statement = connection.createStatement();
             ResultSet rs = statement.executeQuery(sql)) {
            while (rs.next()) {
                tableNames.add(rs.getString(1));
            }
        }
        return tableNames;
    }

    /**
     * Escapes {@code prefix} for embedding inside a single-quoted MySQL {@code LIKE} pattern so
     * that it is matched literally. Without this, {@code _} in a name such as {@code t_order}
     * would match any character, and a quote could terminate the string literal.
     *
     * @param prefix raw table-name prefix
     * @return the escaped text, to be placed between the opening quote and the trailing {@code %}
     */
    private static String escapeLikePrefix(String prefix) {
        StringBuilder escaped = new StringBuilder(prefix.length() + 8);
        for (int i = 0; i < prefix.length(); i++) {
            char c = prefix.charAt(i);
            switch (c) {
                case '\\':
                    // A literal backslash is "\\" in the LIKE pattern, and each of those is written "\\" in the literal.
                    escaped.append("\\\\\\\\");
                    break;
                case '\'':
                    escaped.append("''");
                    break;
                case '%':
                case '_':
                    escaped.append('\\').append(c);
                    break;
                default:
                    escaped.append(c);
            }
        }
        return escaped.toString();
    }

    /**
     * Creates each table in {@code addTableNames} with the same structure as
     * {@code logicTableName}, using {@code CREATE TABLE IF NOT EXISTS ... like ...}.
     * <p>
     * Table names are identifiers and cannot be bound as JDBC parameters, so they are
     * concatenated into the statements; callers must pass trusted names only.
     * <p>
     * Any failure is propagated to the caller after a best-effort rollback. Previously, failures
     * were silently swallowed and callers could go on to register tables that were never created.
     * NOTE(review): assuming a MySQL backend, CREATE TABLE causes an implicit commit, so the
     * rollback cannot undo tables created before the failing statement.
     *
     * @param addTableNames  names of the actual tables to create; {@code null} or empty is a no-op
     * @param logicTableName existing table whose definition is copied
     */
    @SneakyThrows
    public void createTablesIfNotExists(Collection<String> addTableNames, String logicTableName) {
        if (addTableNames == null || addTableNames.isEmpty()) {
            return;
        }
        List<String> sqls = new ArrayList<>(addTableNames.size());
        for (String actualName : addTableNames) {
            sqls.add("CREATE TABLE IF NOT EXISTS " + actualName + " like " + logicTableName);
        }
        // Execute all CREATE statements on one connection.
        try (Connection connection = DriverManager.getConnection(url, username, password);
             Statement statement = connection.createStatement()) {
            connection.setAutoCommit(false);
            try {
                for (String sql : sqls) {
                    log.debug("---> 创建表：{}", sql);
                    statement.execute(sql);
                }
                connection.commit();
            } catch (Exception e) {
                try {
                    connection.rollback();
                } catch (Exception rollbackFailure) {
                    // Keep the original cause primary; do not let a rollback error mask it.
                    e.addSuppressed(rollbackFailure);
                }
                throw e;
            }
        }
    }

    /**
     * Replaces the data-node related state of {@code tableRule} with {@code newDataNodes}.
     * <p>
     * ShardingSphere exposes no API for changing a table rule's configuration, so its private
     * fields {@code actualDataNodes}, {@code actualTables}, {@code dataNodeIndexMap} and
     * {@code datasourceToTablesMap} are overwritten via reflection. {@code Field.set} may write a
     * final instance field once {@code setAccessible(true)} has succeeded, so no modifier
     * tampering is needed. The removed {@code Field.modifiers} hack fails on JDK 12+.
     * <p>
     * NOTE(review): {@code datasourceToTablesMap} is replaced by a map with the single entry
     * {@code dataSourceName -> all table names}, so this assumes every node in
     * {@code newDataNodes} belongs to {@code dataSourceName} — confirm with callers.
     * NOTE(review): the fields are swapped one by one without synchronization; concurrent
     * routing may briefly observe a mix of old and new values.
     *
     * @param dataSourceName data source that owns the new data nodes
     * @param tableRule      rule to update in place
     * @param newDataNodes   full, ordered list of data nodes; stored as-is (not copied), and its
     *                       order determines each node's index in {@code dataNodeIndexMap}
     */
    @SneakyThrows
    public void freshActualDataNodes(String dataSourceName, TableRule tableRule, List<DataNode> newDataNodes) {
        Set<String> actualTables = Sets.newHashSet();
        Map<DataNode, Integer> dataNodeIndexMap = Maps.newHashMap();
        int index = 0;
        for (DataNode dataNode : newDataNodes) {
            actualTables.add(dataNode.getTableName());
            dataNodeIndexMap.put(dataNode, index++);
        }
        Map<String, Collection<String>> datasourceToTablesMap = Maps.newHashMap();
        datasourceToTablesMap.put(dataSourceName, actualTables);

        setTableRuleField(tableRule, "actualDataNodes", newDataNodes);
        setTableRuleField(tableRule, "actualTables", actualTables);
        setTableRuleField(tableRule, "dataNodeIndexMap", dataNodeIndexMap);
        setTableRuleField(tableRule, "datasourceToTablesMap", datasourceToTablesMap);
    }

    /**
     * Writes {@code value} into the declared field {@code fieldName} of {@code tableRule},
     * bypassing access checks.
     */
    private static void setTableRuleField(TableRule tableRule, String fieldName, Object value)
            throws NoSuchFieldException, IllegalAccessException {
        Field field = TableRule.class.getDeclaredField(fieldName);
        field.setAccessible(true);
        field.set(tableRule, value);
    }

    /**
     * Walks every table rule of the sharding rule and lets {@code shardingDataNodeAppender}
     * append new data nodes to the rules for which it reports that appending is needed.
     */
    public void refreshActualDataNodes() {
        log.debug("===== 动态刷新 actualDataNodes =====");
        shardingRule.getTableRules().forEach((logicName, tableRule) -> {
            log.info("===> 刷新表 {}", logicName);
            if (shardingDataNodeAppender.needAppend(tableRule)) {
                shardingDataNodeAppender.append(tableRule);
            }
        });
    }
}
